{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to force tool calling behavior\n",
    "\n",
    ":::info Prerequisites\n",
    "\n",
    "This guide assumes familiarity with the following concepts:\n",
    "- [Chat models](/docs/concepts/#chat-models)\n",
    "- [LangChain Tools](/docs/concepts/#tools)\n",
    "- [How to use a model to call tools](/docs/how_to/tool_calling)\n",
    ":::\n",
    "\n",
    "In order to force our LLM to spelect a specific tool, we can use the `tool_choice` parameter to ensure certain behavior. First, let's define our model and tools:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "@tool\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Adds a and b.\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "@tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiplies a and b.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "tools = [add, multiply]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# | output: false\n",
    "# | echo: false\n",
    "\n",
    "%pip install -qU langchain langchain_openai\n",
    "\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass()\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, we can force our tool to call the multiply tool by using the following code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_9cViskmLvPnHjXk9tbVla5HA', 'function': {'arguments': '{\"a\":2,\"b\":4}', 'name': 'Multiply'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 9, 'prompt_tokens': 103, 'total_tokens': 112}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-095b827e-2bdd-43bb-8897-c843f4504883-0', tool_calls=[{'name': 'Multiply', 'args': {'a': 2, 'b': 4}, 'id': 'call_9cViskmLvPnHjXk9tbVla5HA'}], usage_metadata={'input_tokens': 103, 'output_tokens': 9, 'total_tokens': 112})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "llm_forced_to_multiply = llm.bind_tools(tools, tool_choice=\"Multiply\")\n",
    "llm_forced_to_multiply.invoke(\"what is 2 + 4\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Even if we pass it something that doesn't require multiplcation - it will still call the tool!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also just force our tool to select at least one of our tools by passing in the \"any\" (or \"required\" which is OpenAI specific) keyword to the `tool_choice` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_mCSiJntCwHJUBfaHZVUB2D8W', 'function': {'arguments': '{\"a\":1,\"b\":2}', 'name': 'Add'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 94, 'total_tokens': 109}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-28f75260-9900-4bed-8cd3-f1579abb65e5-0', tool_calls=[{'name': 'Add', 'args': {'a': 1, 'b': 2}, 'id': 'call_mCSiJntCwHJUBfaHZVUB2D8W'}], usage_metadata={'input_tokens': 94, 'output_tokens': 15, 'total_tokens': 109})"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "llm_forced_to_use_tool = llm.bind_tools(tools, tool_choice=\"any\")\n",
    "llm_forced_to_use_tool.invoke(\"What day is today?\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
